Show and Tell Model --- Given an image, the model generates captions.
The "Show and Tell" model presented in this notebook is based on work in https://github.com/tensorflow/models/tree/master/im2txt. Modifications are made to make training much faster (from one week with GPU to a few hours with CPU only). More specifically, the following modifications are made:
The following diagram illustrates the model architecture (for details, see the Show and Tell model on GitHub).
Send any feedback to datalab-feedback@google.com.
You will need about 150 GB of disk. An n1-standard-1 VM is probably not enough; a high-memory machine type is recommended. If you use the "datalab create" command to create the Datalab instance, add the "--machine-type n1-highmem-2" option. See https://cloud.google.com/datalab/docs/how-to/machine-type for instructions.
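For example, the instance could be created from a terminal with the Cloud SDK (or Cloud Shell) along these lines; the instance name "img2txt-datalab" is just a placeholder, and other options such as disk size are described on the page above:

datalab create img2txt-datalab --machine-type n1-highmem-2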
We will use MSCOCO data. Although we only use the cat- and dog-related images and captions, we need to download the zip packages containing the full data.
In [ ]:
# Download Images Data
!mkdir -p /content/datalab/img2txt/images
!wget -P /content/datalab/img2txt/ http://msvocds.blob.core.windows.net/coco2014/train2014.zip
!wget -P /content/datalab/img2txt/ http://msvocds.blob.core.windows.net/coco2014/val2014.zip
!unzip -q -j /content/datalab/img2txt/train2014.zip -d /content/datalab/img2txt/images
!unzip -q -j /content/datalab/img2txt/val2014.zip -d /content/datalab/img2txt/images
In [3]:
# Download Captions Data
!wget -P /content/datalab/img2txt/ http://msvocds.blob.core.windows.net/annotations-1-0-3/captions_train-val2014.zip
!unzip -q -j /content/datalab/img2txt/captions_train-val2014.zip -d /content/datalab/img2txt/
In [1]:
from datetime import datetime
from random import randint
import os
import shutil
import six
import tempfile
import tensorflow as tf
from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3
from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_arg_scope
import yaml
In [2]:
def save_vocab(word_to_id, vocab_file):
"""Save vocabulary to file."""
with tf.gfile.Open(vocab_file, 'w') as fw:
yaml.dump(word_to_id, fw, default_flow_style=False)
def load_vocab(vocab_file):
"""Load vocabulary from file."""
with tf.gfile.Open(vocab_file, 'r') as fr:
return yaml.load(fr)
def get_instances_size(file_pattern):
"""Count training instances from tf.example file."""
c = sum(1 for x in tf.python_io.tf_record_iterator(file_pattern))
print('instances size is %d' % c)
return c
In [3]:
INCEPTION_V3_CHECKPOINT = 'gs://cloud-ml-data/img/flower_photos/inception_v3_2016_08_28.ckpt'
INCEPTION_EXCLUDED_VARIABLES = ['InceptionV3/AuxLogits', 'InceptionV3/Logits', 'global_step']
def make_batches(iterable, n):
"""Make batches with iterable."""
l = len(iterable)
for ndx in range(0, l, n):
yield iterable[ndx:min(ndx + n, l)]
def build_image_processing(image_str_tensor):
"""Create image-to-embeddings tf graph."""
def _decode_and_resize(image_str_tensor):
"""Decodes jpeg string, resizes it and returns a uint8 tensor."""
# These constants are set by Inception v3's expectations.
height = 299
width = 299
channels = 3
image = tf.image.decode_jpeg(image_str_tensor, channels=channels)
image = tf.expand_dims(image, 0)
image = tf.image.resize_bilinear(image, [height, width], align_corners=False)
image = tf.squeeze(image, squeeze_dims=[0])
image = tf.cast(image, dtype=tf.uint8)
return image
image = tf.map_fn(_decode_and_resize, image_str_tensor, back_prop=False, dtype=tf.uint8)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image = tf.subtract(image, 0.5)
inception_input = tf.multiply(image, 2.0)
# Build Inception layers, which expect a tensor of type float from [-1, 1)
# and shape [batch_size, height, width, channels].
with tf.contrib.slim.arg_scope(inception_v3_arg_scope()):
_, end_points = inception_v3(inception_input, is_training=False)
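# The 'PreLogits' endpoint is the 2048-d activation just before Inception's classification layer; use it as the image embedding.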
embeddings = end_points['PreLogits']
inception_embeddings = tf.squeeze(embeddings, [1, 2], name='SpatialSqueeze')
return inception_embeddings
def load_inception_checkpoint(sess, vars_to_restore, checkpoint_path=None):
"""Loal inception checkpoint to session."""
saver = tf.train.Saver(vars_to_restore)
if checkpoint_path is None:
checkpoint_dir = tempfile.mkdtemp()
try:
checkpoint_tmp = os.path.join(checkpoint_dir, 'checkpoint')
with tf.gfile.Open(INCEPTION_V3_CHECKPOINT, 'r') as f_in, tf.gfile.Open(checkpoint_tmp, 'w') as f_out:
f_out.write(f_in.read())
saver.restore(sess, checkpoint_tmp)
finally:
shutil.rmtree(checkpoint_dir)
else:
saver.restore(sess, checkpoint_path)
In [4]:
# Extract vocabs, images files, captions that are only related to cats and dogs.
from collections import Counter
import six
# If KEYWORDS is empty, all data is included. Otherwise, include only images that have any of these words in their captions.
KEYWORDS = {'cat', 'cats', 'kitten', 'kittens', 'dog', 'dogs', 'puppy', 'puppies'}
# Sentence start, sentence end, and unknown word.
CONTROL_WORDS = ['<s>', '</s>', '<unk>']
def extract(train_content, val_content):
"""Extract vocab, captions, and image files from raw data.
Returns:
A tuple of the following:
- word_to_id: the vocabulary, a dict mapping each word to an integer id.
- id_wids: a dict mapping image id to a list of captions, where each caption is
represented by a list of word ids.
- id_imagefiles: a dict mapping image id to the path of its image file.
"""
id_captions = [(x['image_id'], x['caption']) for x in train_content['annotations']]
id_captions += [(x['image_id'], x['caption']) for x in val_content['annotations']]
id_captions = [(k, v.replace('.', '').replace(',', '').lower().split()) for k, v in id_captions]
# key - id, value - a list of captions
id_captions_filtered = {}
for x in id_captions:
if not KEYWORDS or (KEYWORDS & set(x[1])):
id_captions_filtered.setdefault(x[0], []).append(x[1])
print('number of captions is %d' % sum(len(x) for x in id_captions_filtered.values()))
words = [w for captions in id_captions_filtered.values() for caption in captions for w in caption]
counts = Counter(words)
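# Keep only words that appear more than 5 times; all other words will map to <unk>.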
counts = [x for x in counts.items() if x[1] > 5]
counts = sorted(counts, key=lambda x: (x[1]), reverse=True)
counts += [(x, 0) for x in CONTROL_WORDS]
word_to_id = {str(word_cnt_pair[0]): idx for idx, word_cnt_pair in enumerate(counts)}
print('vocab size is %d' % len(word_to_id))
id_wids = {}
for k, v in six.iteritems(id_captions_filtered):
sentences = []
for caption in v:
wids = [word_to_id[x] if x in word_to_id else word_to_id['<unk>'] for x in caption]
wids = [word_to_id['<s>']] + wids + [word_to_id['</s>']]
sentences.append(wids)
id_wids[k] = sentences
id_imagefiles = {x['id']: x['file_name'] for x in train_content['images']}
id_imagefiles.update({x['id']: x['file_name'] for x in val_content['images']})
id_imagefiles_filtered = {k: v for k, v in six.iteritems(id_imagefiles) if k in id_wids}
print('number of images is %d' % len(id_imagefiles_filtered))
return word_to_id, id_wids, id_imagefiles_filtered
In [5]:
# Load data from files.
import json
with open('/content/datalab/img2txt/captions_val2014.json', 'r') as f:
val_content = json.load(f)
with open('/content/datalab/img2txt/captions_train2014.json', 'r') as f:
train_content = json.load(f)
word_to_id, id_wids, id_imagefiles = extract(train_content, val_content)
In [6]:
# Save the vocab so we can convert word ids to words in prediction.
save_vocab(word_to_id, '/content/datalab/img2txt/vocab.yaml')
In [7]:
def transform(id_imagefiles, id_wids, image_dir, output_dir, train_filename, eval_filename, test_filename, batch_size):
"""Convert images into embeddings, join with captions by id, splits results into train/eval/test,
and save to tf SequenceExample file.
Note that train/eval data will be SequenceExample, but test data will be text
(a list of image file paths) because the final model expects raw images as input.
"""
def _int64_feature(value):
"""Wrapper for inserting an int64 Feature into a SequenceExample proto."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _float_feature(value):
"""Wrapper for inserting an int64 Feature into a SequenceExample proto."""
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _int64_feature_list(values):
"""Wrapper for inserting an int64 FeatureList into a SequenceExample proto."""
return tf.train.FeatureList(feature=[_int64_feature(v) for v in values])
tf.gfile.MakeDirs(output_dir)
g = tf.Graph()
with g.as_default():
image_str_tensor = tf.placeholder(tf.string, shape=None)
inception_embeddings = build_image_processing(image_str_tensor)
vars_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=INCEPTION_EXCLUDED_VARIABLES)
with tf.Session(graph=g) as sess:
load_inception_checkpoint(sess, vars_to_restore)
# Write to tf.example files.
train_file = os.path.join(output_dir, train_filename)
eval_file = os.path.join(output_dir, eval_filename)
writer_train = tf.python_io.TFRecordWriter(train_file)
writer_eval = tf.python_io.TFRecordWriter(eval_file)
writer_test = tf.gfile.Open(os.path.join(output_dir, test_filename), 'w')
batches = make_batches(list(six.iteritems(id_imagefiles)), batch_size)
num_of_batches = len(id_imagefiles) / batch_size + 1
for batch_num, b in enumerate(batches):
start = datetime.now()
image_bytes = []
for img in b:
with tf.gfile.Open(os.path.join(image_dir, img[1]), 'r') as f:
image_bytes.append(f.read())
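# Run the whole batch through Inception once; each image becomes a 2048-d embedding vector.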
embs = sess.run(inception_embeddings, feed_dict={image_str_tensor: image_bytes})
for img, emb in zip(b, embs):
rnd_num = randint(0, 100)
# 5% eval, 5% test, 90% training
if rnd_num > 4:
writer = writer_train if rnd_num > 9 else writer_eval
img_id = img[0]
for caption_wids in id_wids[img_id]:
context = tf.train.Features(feature={"id": _int64_feature(img_id), "emb": _float_feature(emb.tolist())})
feature_lists = tf.train.FeatureLists(feature_list={"wids": _int64_feature_list(caption_wids)})
sequence_example = tf.train.SequenceExample(context=context, feature_lists=feature_lists)
writer.write(sequence_example.SerializeToString())
else:
writer_test.write('%d:%s\n' % (img[0], img[1]))
elapsed = datetime.now() - start
print('processed batch %d of %d in %s' % (batch_num, num_of_batches, str(elapsed)))
writer_train.close()
writer_eval.close()
writer_test.close()
In [14]:
transform(
id_imagefiles,
id_wids,
image_dir='/content/datalab/img2txt/images',
output_dir='/content/datalab/img2txt/transformed',
train_filename='train',
eval_filename='eval',
test_filename='test.txt',
batch_size=500)
In [16]:
!ls /content/datalab/img2txt/transformed -l -h
In [24]:
import tensorflow as tf
def parse_sequence_example(serialized):
"""Parses a tensorflow.SequenceExample into an image and caption.
Args:
serialized: A scalar string Tensor; a single serialized SequenceExample.
Returns:
id: a scalar integer Tensor.
emb: image embeddings, a 1-D Tensor with shape [2048].
wids: word ids, a 1-D Tensor with shape [None].
"""
context, sequence = tf.parse_single_sequence_example(
serialized,
context_features={
'id': tf.FixedLenFeature([], dtype=tf.int64),
'emb': tf.FixedLenFeature([2048], dtype=tf.float32)
},
sequence_features={
'wids': tf.FixedLenSequenceFeature([], dtype=tf.int64),
})
return context['id'], context['emb'], sequence['wids']
def prefetch_input_data(file_pattern, batch_size):
"""Prefetches string values from disk vocab_idvocab_idinto an input queue.
Args:
file_pattern: file patterns (e.g. /tmp/train_data-?????-of-00100).
batch_size: Model batch size used to determine queue capacity.
Returns:
A Queue containing prefetched string values.
"""
data_files = tf.gfile.Glob(file_pattern)
filename_queue = tf.train.string_input_producer(data_files, shuffle=True, capacity=16, name='filename_queue')
capacity = 1000 + 100 * batch_size
values_queue = tf.RandomShuffleQueue(
capacity=capacity,
min_after_dequeue=1000,
dtypes=[tf.string],
name="random_input_queue")
enqueue_ops = []
reader = tf.TFRecordReader()
_, value = reader.read(filename_queue)
enqueue_ops.append(values_queue.enqueue([value]))
tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner(values_queue, enqueue_ops))
return values_queue
def build_graph(serialized_sequence_example, vocab_size, train_batch_size, embedding_size, lstm_size, mode):
""" Build the main TensorFlow graph that will be shared by training and evaluation.
"""
uniform_initializer = tf.random_uniform_initializer(minval=-0.08, maxval=0.08)
id, img_emb, wids = parse_sequence_example(serialized_sequence_example)
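# The decoder input is each caption without its last word; the target is the caption shifted
# left by one, so at every step the LSTM predicts the next word. The indicator marks the real
# (unpadded) positions for the loss.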
caption_length = tf.shape(wids)[0]
input_length = tf.expand_dims(tf.subtract(caption_length, 1), 0)
input_seq = tf.slice(wids, [0], input_length)
target_seq = tf.slice(wids, [1], input_length)
indicator = tf.ones(input_length, dtype=tf.int32)
enqueue_list = [[img_emb, input_seq, target_seq, indicator]]
img_embs, input_seqs, target_seqs, input_mask = tf.train.batch_join(
enqueue_list,
batch_size=train_batch_size,
capacity=train_batch_size * 2,
dynamic_pad=True,
name="batch_and_pad")
with tf.variable_scope("seq_embedding"), tf.device("/cpu:0"):
embedding_map = tf.get_variable(
name="map",
shape=[vocab_size, embedding_size], initializer=uniform_initializer)
seq_embeddings = tf.nn.embedding_lookup(embedding_map, input_seqs)
with tf.variable_scope("image_embedding") as scope:
image_embeddings = tf.contrib.layers.fully_connected(
inputs=img_embs,
num_outputs=embedding_size,
activation_fn=None,
weights_initializer=uniform_initializer,
biases_initializer=None,
scope=scope)
lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units=lstm_size, state_is_tuple=True)
if mode == 'train':
lstm_cell = tf.contrib.rnn.DropoutWrapper(lstm_cell, input_keep_prob=0.7, output_keep_prob=0.7)
with tf.variable_scope("lstm", initializer=tf.random_uniform_initializer(minval=-0.08, maxval=0.08)) as lstm_scope:
zero_state = lstm_cell.zero_state(batch_size=image_embeddings.get_shape()[0], dtype=tf.float32)
# Use image_embeddings as initial state.
_, initial_state = lstm_cell(image_embeddings, zero_state)
lstm_scope.reuse_variables()
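# sequence_length counts the real (unpadded) words per caption so dynamic_rnn stops at the right step for each example.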
sequence_length = tf.reduce_sum(input_mask, 1)
lstm_outputs, _ = tf.nn.dynamic_rnn(cell=lstm_cell,
inputs=seq_embeddings,
sequence_length=sequence_length,
initial_state=initial_state,
dtype=tf.float32,
scope=lstm_scope)
# lstm_outputs's dim is [batch_size, max_seq_length, lstm_cell.output_size]
# Reshape it to 2D Tensor [batch * max_seq_length, lstm_cell.output_size] for loss computation.
lstm_outputs = tf.reshape(lstm_outputs, [-1, lstm_cell.output_size])
with tf.variable_scope("logits") as logits_scope:
logits = tf.contrib.layers.fully_connected(
inputs=lstm_outputs,
num_outputs=vocab_size,
activation_fn=None,
weights_initializer=uniform_initializer,
scope=logits_scope)
# Similarly, reshape targets to [batch * max_seq_length]
targets = tf.reshape(target_seqs, [-1])
weights = tf.to_float(tf.reshape(input_mask, [-1]))
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=logits)
batch_loss = tf.div(tf.reduce_sum(tf.multiply(losses, weights)), tf.reduce_sum(weights), name="batch_loss")
tf.summary.scalar("losses/batch_loss", batch_loss)
global_step = tf.Variable(
initial_value=0,
name="global_step",
trainable=False,
collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES])
return batch_loss, losses, weights, global_step
In [82]:
def train_graph(vocab_size, train_batch_size, training_file_pattern, embedding_size=1024, lstm_size=512):
"""Build the training graph."""
train_instances_size = get_instances_size(training_file_pattern)
g = tf.Graph()
with g.as_default():
input_queue = prefetch_input_data(training_file_pattern, batch_size=train_batch_size)
serialized_sequence_example = input_queue.dequeue()
total_loss, _, _, global_step = build_graph(
serialized_sequence_example, vocab_size, train_batch_size, embedding_size, lstm_size, 'train')
learning_rate = tf.constant(2.0) # initial_learning_rate
learning_rate_decay_factor = 0.5
num_batches_per_epoch = (train_instances_size / train_batch_size)
decay_steps = int(num_batches_per_epoch * 8) # num_epochs_per_decay
def _learning_rate_decay_fn(learning_rate, global_step):
return tf.train.exponential_decay(
learning_rate,
global_step,
decay_steps=decay_steps,
decay_rate=learning_rate_decay_factor,
staircase=True)
train_op = tf.contrib.layers.optimize_loss(
loss=total_loss,
global_step=global_step,
learning_rate=learning_rate,
optimizer='SGD',
clip_gradients=5.0,
learning_rate_decay_fn=_learning_rate_decay_fn)
saver = tf.train.Saver(max_to_keep=5)
return g, train_op, global_step, saver
In [26]:
# Remove any previously trained model
!rm -r -f /content/datalab/img2txt/train
In [27]:
vocab = load_vocab('/content/datalab/img2txt/vocab.yaml')
vocab_size = len(vocab)
graph, train_op, global_step, saver = train_graph(
vocab_size,
train_batch_size=64,
training_file_pattern='/content/datalab/img2txt/transformed/train')
tf.contrib.slim.learning.train(
train_op,
'/content/datalab/img2txt/train',
log_every_n_steps=100,
graph=graph,
global_step=global_step,
number_of_steps=10000,
saver=saver)
# Save inception checkpoint with the model.
inception_checkpoint = os.path.join('/content/datalab/img2txt/train', 'inception_checkpoint')
with tf.gfile.Open(INCEPTION_V3_CHECKPOINT, 'r') as f_in, tf.gfile.Open(inception_checkpoint, 'w') as f_out:
f_out.write(f_in.read())
In [28]:
from google.datalab.ml import Summary
summary = Summary('/content/datalab/img2txt/train')
summary.list_events()
In [29]:
summary.plot('losses/batch_loss')
In [80]:
import math
import numpy as np
def eval_graph(vocab_size, eval_batch_size, eval_file_pattern, embedding_size=1024, lstm_size=512):
"""Build evaluation graph."""
g = tf.Graph()
with g.as_default():
input_queue = prefetch_input_data(eval_file_pattern, batch_size=eval_batch_size)
serialized_sequence_example = input_queue.dequeue()
_, losses, weights, global_step = build_graph(serialized_sequence_example, vocab_size, eval_batch_size, embedding_size, lstm_size, 'eval')
saver = tf.train.Saver()
return g, losses, weights, global_step, saver
def eval_model(vocab_size, train_dir, eval_file_pattern, eval_batch_size=64):
"""Evaluate a trained model with evaluation data."""
eval_instances_size = get_instances_size(eval_file_pattern)
graph, losses, weights, global_step, saver = eval_graph(vocab_size, eval_batch_size=eval_batch_size, eval_file_pattern=eval_file_pattern)
checkpoint = tf.train.latest_checkpoint(train_dir)
with tf.Session(graph=graph) as sess:
saver.restore(sess, checkpoint)
global_step_val = tf.train.global_step(sess, global_step.name)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
num_eval_batches = int(math.ceil(eval_instances_size / float(eval_batch_size)))
sum_losses = 0.
sum_weights = 0.
for i in xrange(num_eval_batches):
losses_val, weights_val = sess.run([losses, weights])
sum_losses += np.sum(losses_val * weights_val)
sum_weights += np.sum(weights_val)
if i % 10 == 0:
tf.logging.info("Computed losses for %d of %d batches.", i + 1, num_eval_batches)
perplexity = math.exp(sum_losses / sum_weights)
tf.logging.info("Perplexity = %f", perplexity)
tf.logging.info("Finished processing evaluation at global step %d.", global_step_val)
coord.request_stop()
coord.join(threads, stop_grace_period_secs=10)
In [81]:
eval_model(vocab_size, '/content/datalab/img2txt/train', '/content/datalab/img2txt/transformed/eval')
The prediction graph is mostly similar to the train/eval graph, and they share all trainable variables. The differences are: it takes a raw image and runs Inception to compute the embedding on the fly, instead of reading precomputed embeddings from SequenceExamples; the LSTM is unrolled one step at a time, with its state fed in and out through placeholders so that we can run beam search over the words; and there is no dropout or loss computation. A small greedy-decoding sketch follows the next cell.
In [46]:
import tensorflow as tf
def predict_graph(vocab_size, embedding_size=1024, lstm_size=512):
g = tf.Graph()
with g.as_default():
image_feed = tf.placeholder(dtype=tf.string, shape=[], name="image_feed")
input_feed = tf.placeholder(dtype=tf.int64, shape=[None], name="input_feed")
images = tf.expand_dims(image_feed, 0)
input_seqs = tf.expand_dims(input_feed, 1)
inception_embeddings = build_image_processing(images)
inception_vars = tf.contrib.slim.get_variables_to_restore(exclude=INCEPTION_EXCLUDED_VARIABLES)
with tf.variable_scope("seq_embedding"):
embedding_map = tf.get_variable(
name="map",
shape=[vocab_size, embedding_size])
seq_embeddings = tf.nn.embedding_lookup(embedding_map, input_seqs)
with tf.variable_scope("image_embedding") as scope:
image_embeddings = tf.contrib.layers.fully_connected(
inputs=inception_embeddings,
num_outputs=embedding_size,
activation_fn=None,
biases_initializer=None,
scope=scope)
lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units=lstm_size, state_is_tuple=True)
with tf.variable_scope("lstm") as lstm_scope:
zero_state = lstm_cell.zero_state(batch_size=image_embeddings.get_shape()[0], dtype=tf.float32)
_, initial_state = lstm_cell(image_embeddings, zero_state)
# Concatenate (c, h) into a single flat state tensor so it can be fetched and fed back in.
initial_state = tf.concat(axis=1, values=initial_state, name="initial_state")
lstm_scope.reuse_variables()
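# state_feed carries the flattened (c, h) LSTM state from the previous step; it is split back into c and h below to run a single LSTM step.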
state_feed = tf.placeholder(dtype=tf.float32, shape=[None, sum(lstm_cell.state_size)], name="state_feed")
state_tuple = tf.split(value=state_feed, num_or_size_splits=2, axis=1)
lstm_outputs, state_tuple = lstm_cell(inputs=tf.squeeze(seq_embeddings, axis=[1]), state=state_tuple)
lstm_state = tf.concat(axis=1, values=state_tuple, name="state")
with tf.variable_scope("logits") as logits_scope:
logits = tf.contrib.layers.fully_connected(
inputs=lstm_outputs,
num_outputs=vocab_size,
activation_fn=None,
scope=logits_scope)
softmax = tf.nn.softmax(logits, name="softmax")
trainable_vars = tf.contrib.slim.get_variables_to_restore(exclude=['InceptionV3/*'])
return g, image_feed, input_feed, state_feed, initial_state, lstm_state, softmax, inception_vars, trainable_vars
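Before adding beam search, here is a minimal greedy-decoding sketch showing how the prediction graph above is driven one step at a time. It is not used below and only illustrates the feed/fetch pattern; it assumes the vocab file and a trained checkpoint already exist under /content/datalab/img2txt (reusing the load_vocab and load_inception_checkpoint helpers defined earlier), and the sample image path is just one of the MSCOCO files downloaded above.
In [ ]:
# A greedy-decoding sketch (effectively beam_size=1): feed the image once to get the
# initial LSTM state, then repeatedly feed the last word and state to get the next word.
import numpy as np
vocab = load_vocab('/content/datalab/img2txt/vocab.yaml')
id_to_word = {v: k for k, v in six.iteritems(vocab)}
(g, image_feed, input_feed, state_feed, initial_state,
 lstm_state, softmax, inception_vars, trainable_vars) = predict_graph(len(vocab))
with tf.Session(graph=g) as sess:
  load_inception_checkpoint(sess, inception_vars, '/content/datalab/img2txt/train/inception_checkpoint')
  tf.train.Saver(trainable_vars).restore(sess, tf.train.latest_checkpoint('/content/datalab/img2txt/train'))
  with tf.gfile.GFile('/content/datalab/img2txt/images/COCO_val2014_000000524382.jpg', 'r') as f:
    image_bytes = f.read()
  # Feed the image once to get the initial LSTM state.
  state = sess.run(initial_state, feed_dict={image_feed: image_bytes})
  word, words = vocab['<s>'], []
  for _ in range(20):
    probs, state = sess.run([softmax, lstm_state],
                            feed_dict={input_feed: [word], state_feed: state})
    word = int(np.argmax(probs[0]))
    if word == vocab['</s>']:
      break
    words.append(id_to_word[word])
print(' '.join(words))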
We also need a "beam search": on one extreme, if max_caption_length is 20, we will have vocab^20 results and we will pick the one with greatest probs; on the other extreme, we pick only the top word for each step, and there will be only one result, which may not be the one with greatest probs. "Beam search" keeps track of top n paths for each step, and the final results will also be n predictions.
In [70]:
import heapq
import math
import numpy as np
class Caption(object):
"""Represents a complete or partial caption."""
def __init__(self, sentence, state, logprob, score, metadata=None):
"""Initializes the Caption.
Args:
sentence: List of word ids in the caption.
state: Model state after generating the previous word.
logprob: Log-probability of the caption.
score: Score of the caption.
"""
self.sentence = sentence
self.state = state
self.logprob = logprob
self.score = score
def __cmp__(self, other):
"""Compares Captions by score."""
assert isinstance(other, Caption)
if self.score == other.score:
return 0
elif self.score < other.score:
return -1
else:
return 1
# For Python 3 compatibility (__cmp__ is deprecated).
def __lt__(self, other):
assert isinstance(other, Caption)
return self.score < other.score
# Also for Python 3 compatibility.
def __eq__(self, other):
assert isinstance(other, Caption)
return self.score == other.score
class TopN(object):
"""Maintains the top n elements of an incrementally provided set."""
def __init__(self, n):
self._n = n
self._data = []
def size(self):
assert self._data is not None
return len(self._data)
def push(self, x):
"""Pushes a new element."""
assert self._data is not None
if len(self._data) < self._n:
heapq.heappush(self._data, x)
else:
heapq.heappushpop(self._data, x)
def extract(self, sort=False):
"""Extracts all elements from the TopN. This is a destructive operation.
The only method that can be called immediately after extract() is reset().
Args:
sort: Whether to return the elements in descending sorted order.
Returns:
A list of data; the top n elements provided to the set.
"""
assert self._data is not None
data = self._data
self._data = None
if sort:
data.sort(reverse=True)
return data
def reset(self):
"""Returns the TopN to an empty state."""
self._data = []
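To make the behavior of these helpers concrete, here is a tiny usage sketch with made-up scores: TopN keeps only the n highest-scoring Caption objects pushed into it.
In [ ]:
# Tiny sanity check of TopN/Caption (the scores are made up).
import math
top2 = TopN(2)
for score in [0.1, 0.5, 0.3, 0.9]:
  top2.push(Caption(sentence=[], state=None, logprob=math.log(score), score=score))
print([c.score for c in top2.extract(sort=True)])  # prints [0.9, 0.5]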
In [90]:
from PIL import Image
from IPython.display import display
import numpy as np
import math
import os
class ShowAndTellModel(object):
def __init__(self, train_dir, vocab_file, max_caption_length=20, beam_size=5):
self._vocab = load_vocab(vocab_file)
self._train_dir = train_dir
self._max_caption_length = max_caption_length
self._beam_size = beam_size
def __enter__(self):
self._graph, self._image_feed, self._input_feed, self._state_feed, self._initial_state, \
self._lstm_state, self._softmax, inception_vars, trainable_vars = predict_graph(len(self._vocab))
self._sess = tf.Session(graph=self._graph)
inception_checkpoint = os.path.join(self._train_dir, 'inception_checkpoint')
load_inception_checkpoint(self._sess, inception_vars, inception_checkpoint)
saver = tf.train.Saver(trainable_vars)
checkpoint_path = tf.train.latest_checkpoint(self._train_dir)
saver.restore(self._sess, checkpoint_path)
return self
def __exit__(self, *args):
self._sess.close()
def _process_results(self, captions):
id_to_word = {v: k for k, v in six.iteritems(self._vocab)}
for caption in captions:
words = [id_to_word[x] for x in caption.sentence]
words = filter(lambda x: x not in ['<s>', '</s>'], words)
yield ' '.join(words)
def _predict(self, img_file):
with tf.gfile.GFile(img_file, 'r') as f:
image_bytes = f.read()
init_state = self._sess.run(self._initial_state, feed_dict={self._image_feed: image_bytes})
initial_beam = Caption(sentence=[self._vocab['<s>']], state=init_state[0], logprob=0.0, score=0.0)
partial_captions = TopN(self._beam_size)
partial_captions.push(initial_beam)
complete_captions = TopN(self._beam_size)
# Run beam search.
for _ in range(self._max_caption_length - 1):
partial_captions_list = partial_captions.extract()
partial_captions.reset()
input_feed_val = np.array([c.sentence[-1] for c in partial_captions_list])
state_feed_val = np.array([c.state for c in partial_captions_list])
softmax_val, new_states = self._sess.run([self._softmax, self._lstm_state],
feed_dict={self._input_feed: input_feed_val, self._state_feed: state_feed_val})
for i, partial_caption in enumerate(partial_captions_list):
word_probabilities = softmax_val[i]
state = new_states[i]
# For this partial caption, get the beam_size most probable next words.
words_and_probs = list(enumerate(word_probabilities))
words_and_probs.sort(key=lambda x: -x[1])
words_and_probs = words_and_probs[0:self._beam_size]
# Each next word gives a new partial caption.
for w, p in words_and_probs:
if p < 1e-12:
continue # Avoid log(0).
sentence = partial_caption.sentence + [w]
logprob = partial_caption.logprob + math.log(p)
score = logprob
if w == self._vocab['</s>']:
beam = Caption(sentence, state, logprob, score, None)
complete_captions.push(beam)
else:
beam = Caption(sentence, state, logprob, score, None)
partial_captions.push(beam)
if partial_captions.size() == 0:
# We have run out of partial candidates; happens when beam_size = 1.
break
# If we have no complete captions then fall back to the partial captions.
# But never output a mixture of complete and partial captions because a
# partial caption could have a higher score than all the complete captions.
if not complete_captions.size():
complete_captions = partial_captions
return complete_captions.extract(sort=True)
def show_and_tell(self, image_file):
with tf.gfile.GFile(image_file) as f:
img = Image.open(f)
img.thumbnail((299, 299), Image.ANTIALIAS)
display(img)
c = self._predict(image_file)
for r in self._process_results(c):
print(r)
Pick the first 10 instances from the test file.
In [75]:
!head /content/datalab/img2txt/transformed/test.txt
In [91]:
with ShowAndTellModel(train_dir='/content/datalab/img2txt/train',
vocab_file='/content/datalab/img2txt/vocab.yaml') as m:
m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000165854.jpg')
m.show_and_tell('/content/datalab/img2txt/images/COCO_val2014_000000524382.jpg')
m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000524476.jpg')
m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000491728.jpg')
m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000033111.jpg')
m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000344127.jpg')
m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000169365.jpg')
m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000098732.jpg')
m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000557508.jpg')
m.show_and_tell('/content/datalab/img2txt/images/COCO_train2014_000000492030.jpg')
For fun, let's try it on some pictures of my cats!
In [92]:
with ShowAndTellModel(train_dir='/content/datalab/img2txt/train',
vocab_file='/content/datalab/img2txt/vocab.yaml') as m:
m.show_and_tell('gs://bradley-sample-notebook-data/chopin_vivaldi.jpg')
m.show_and_tell('gs://bradley-sample-notebook-data/vivaldi_chopin_tail.jpg')
m.show_and_tell('gs://bradley-sample-notebook-data/vivaldi.jpg')
In [ ]: